{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feathr Feature Store on Azure Demo Notebook\n",
    "\n",
    "This notebook illustrates the use of Feature Store to create a model that predicts NYC Taxi fares. It includes these steps:\n",
    "\n",
    "\n",
    "This tutorial demonstrates the key capabilities of Feathr, including:\n",
    "\n",
    "1. Install and set up Feathr with Azure\n",
    "2. Create shareable features with Feathr feature definition configs.\n",
    "3. Create a training dataset via point-in-time feature join.\n",
    "4. Compute and write features.\n",
    "5. Train a model using these features to predict fares.\n",
    "6. Materialize feature value to online store.\n",
    "7. Fetch feature value in real-time from online store for online scoring.\n",
    "\n",
    "In this tutorial, we use Feathr Feature Store to create a model that predicts NYC Taxi fares. The dataset comes from [here](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page). The feature flow is as below:\n",
    "\n",
    "![Feature Flow](https://github.com/linkedin/feathr/blob/main/docs/images/feature_flow.png?raw=true)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite: Use Quick Start Template to Provision Azure Resources\n",
    "\n",
    "Feathr has native cloud integration. To use Feathr on Azure, you only need three steps:\n",
    "\n",
    "- Get the `Principal ID` of your account by running `az ad signed-in-user show --query objectId -o tsv` in the link below (Select \"Bash\" if asked), and write down that value (something like `b65ef2e0-42b8-44a7-9b55-abbccddeefff`). Think this ID as something representing you when accessing Azure, and it will be used to grant permissions in the next step in the UI.\n",
    "\n",
    "[Launch Cloud Shell](https://shell.azure.com/bash)\n",
    "\n",
    "- Click the button below to deploy a minimal set of Feathr resources for demo purpose. You will need to fill in the `Principal ID` and `Resource Prefix`. You will need \"Owner\" permission of the selected subscription.\n",
    "\n",
    "[![Deploy to Azure](https://aka.ms/deploytoazurebutton)](https://portal.azure.com/#create/Microsoft.Template/uri/https%3A%2F%2Fraw.githubusercontent.com%2Flinkedin%2Ffeathr%2Fmain%2Fdocs%2Fhow-to-guides%2Fazure_resource_provision.json)\n",
    "\n",
    "- Run the cells below.\n",
    "\n",
    "And the architecture is as below. In the above template, we are using Synapse as Spark provider, use Azure Data Lake Gen2 as offline store, and use Redis as online store, Azure Purview (Apache Atlas compatible) as feature reigstry. \n",
    "\n",
    "\n",
    "![Architecture](https://github.com/linkedin/feathr/blob/main/docs/images/architecture.png?raw=true)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite: Install Feathr \n",
    "\n",
    "Install Feathr using pip:\n",
    "\n",
    "`pip install -U feathr pandavro scikit-learn`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite: Configure the required environment with Feathr Quick Start Template\n",
    "\n",
    "In the first step (Provision cloud resources), you should have provisioned all the required cloud resources. Run the code below to install Feathr, login to Azure to get the required credentials to access more cloud resources."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**REQUIRED STEP: Fill in the resource prefix when provisioning the resources**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resource_prefix = \"feathr_resource_prefix\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install feathr azure-cli  pandavro scikit-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Login to Azure with a device code (You will see instructions in the output):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! az login --use-device-code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import tempfile\n",
    "from datetime import datetime, timedelta\n",
    "from math import sqrt\n",
    "\n",
    "import pandas as pd\n",
    "import pandavro as pdx\n",
    "from feathr import FeathrClient\n",
    "from feathr import BOOLEAN, FLOAT, INT32, ValueType\n",
    "from feathr import Feature, DerivedFeature, FeatureAnchor\n",
    "from feathr import BackfillTime, MaterializationSettings\n",
    "from feathr import FeatureQuery, ObservationSettings\n",
    "from feathr import RedisSink\n",
    "from feathr import INPUT_CONTEXT, HdfsSource\n",
    "from feathr import WindowAggTransformation\n",
    "from feathr import TypedKey\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import train_test_split\n",
    "from azure.identity import DefaultAzureCredential\n",
    "from azure.keyvault.secrets import SecretClient\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get all the required credentials from Azure KeyVault"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get all the required credentials from Azure Key Vault\n",
    "key_vault_name=resource_prefix+\"kv\"\n",
    "synapse_workspace_url=resource_prefix+\"syws\"\n",
    "adls_account=resource_prefix+\"dls\"\n",
    "adls_fs_name=resource_prefix+\"fs\"\n",
    "purview_name=resource_prefix+\"purview\"\n",
    "key_vault_uri = f\"https://{key_vault_name}.vault.azure.net\"\n",
    "credential = DefaultAzureCredential(exclude_interactive_browser_credential=False)\n",
    "client = SecretClient(vault_url=key_vault_uri, credential=credential)\n",
    "secretName = \"FEATHR-ONLINE-STORE-CONN\"\n",
    "retrieved_secret = client.get_secret(secretName).value\n",
    "\n",
    "# Get redis credentials; This is to parse Redis connection string.\n",
    "redis_port=retrieved_secret.split(',')[0].split(\":\")[1]\n",
    "redis_host=retrieved_secret.split(',')[0].split(\":\")[0]\n",
    "redis_password=retrieved_secret.split(',')[1].split(\"password=\",1)[1]\n",
    "redis_ssl=retrieved_secret.split(',')[2].split(\"ssl=\",1)[1]\n",
    "\n",
    "# Set the resource link\n",
    "os.environ['spark_config__azure_synapse__dev_url'] = f'https://{synapse_workspace_url}.dev.azuresynapse.net'\n",
    "os.environ['spark_config__azure_synapse__pool_name'] = 'spark31'\n",
    "os.environ['spark_config__azure_synapse__workspace_dir'] = f'abfss://{adls_fs_name}@{adls_account}.dfs.core.windows.net/feathr_project'\n",
    "os.environ['feature_registry__purview__purview_name'] = f'{purview_name}'\n",
    "os.environ['online_store__redis__host'] = redis_host\n",
    "os.environ['online_store__redis__port'] = redis_port\n",
    "os.environ['online_store__redis__ssl_enabled'] = redis_ssl\n",
    "os.environ['REDIS_PASSWORD']=redis_password\n",
    "os.environ['feature_registry__purview__purview_name'] = f'{purview_name}'\n",
    "feathr_output_path = f'abfss://{adls_fs_name}@{adls_account}.dfs.core.windows.net/feathr_output'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite: Configure the required environment (Don't need to update if using the above Quick Start Template)\n",
    "\n",
    "In the first step (Provision cloud resources), you should have provisioned all the required cloud resources. If you use Feathr CLI to create a workspace, you should have a folder with a file called `feathr_config.yaml` in it with all the required configurations. Otherwise, update the configuration below.\n",
    "\n",
    "The code below will write this configuration string to a temporary location and load it to Feathr. Please still refer to [feathr_config.yaml](https://github.com/linkedin/feathr/blob/main/feathr_project/feathrcli/data/feathr_user_workspace/feathr_config.yaml) and use that as the source of truth. It should also have more explanations on the meaning of each variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "yaml_config = \"\"\"\n",
    "# Please refer to https://github.com/linkedin/feathr/blob/main/feathr_project/feathrcli/data/feathr_user_workspace/feathr_config.yaml for explanations on the meaning of each field.\n",
    "api_version: 1\n",
    "project_config:\n",
    "  project_name: 'feathr_getting_started'\n",
    "  required_environment_variables:\n",
    "    - 'REDIS_PASSWORD'\n",
    "    - 'AZURE_CLIENT_ID'\n",
    "    - 'AZURE_TENANT_ID'\n",
    "    - 'AZURE_CLIENT_SECRET'\n",
    "offline_store:\n",
    "  adls:\n",
    "    adls_enabled: true\n",
    "  wasb:\n",
    "    wasb_enabled: true\n",
    "  s3:\n",
    "    s3_enabled: false\n",
    "    s3_endpoint: 's3.amazonaws.com'\n",
    "  jdbc:\n",
    "    jdbc_enabled: false\n",
    "    jdbc_database: 'feathrtestdb'\n",
    "    jdbc_table: 'feathrtesttable'\n",
    "  snowflake:\n",
    "    url: \"dqllago-ol19457.snowflakecomputing.com\"\n",
    "    user: \"feathrintegration\"\n",
    "    role: \"ACCOUNTADMIN\"\n",
    "spark_config:\n",
    "  spark_cluster: 'azure_synapse'\n",
    "  spark_result_output_parts: '1'\n",
    "  azure_synapse:\n",
    "    dev_url: 'https://feathrazuretest3synapse.dev.azuresynapse.net'\n",
    "    pool_name: 'spark3'\n",
    "    workspace_dir: 'abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/feathr_getting_started'\n",
    "    executor_size: 'Small'\n",
    "    executor_num: 4\n",
    "    feathr_runtime_location: wasbs://public@azurefeathrstorage.blob.core.windows.net/feathr-assembly-LATEST.jar\n",
    "  databricks:\n",
    "    workspace_instance_url: 'https://adb-2474129336842816.16.azuredatabricks.net'\n",
    "    config_template: {'run_name':'','new_cluster':{'spark_version':'9.1.x-scala2.12','node_type_id':'Standard_D3_v2','num_workers':2,'spark_conf':{}},'libraries':[{'jar':''}],'spark_jar_task':{'main_class_name':'','parameters':['']}}\n",
    "    work_dir: 'dbfs:/feathr_getting_started'\n",
    "    feathr_runtime_location: https://azurefeathrstorage.blob.core.windows.net/public/feathr-assembly-LATEST.jar\n",
    "online_store:\n",
    "  redis:\n",
    "    host: 'feathrazuretest3redis.redis.cache.windows.net'\n",
    "    port: 6380\n",
    "    ssl_enabled: True\n",
    "feature_registry:\n",
    "  purview:\n",
    "    type_system_initialization: true\n",
    "    purview_name: 'feathrazuretest3-purview1'\n",
    "    delimiter: '__'\n",
    "\"\"\"\n",
    "tmp = tempfile.NamedTemporaryFile(mode='w', delete=False)\n",
    "with open(tmp.name, \"w\") as text_file:\n",
    "    text_file.write(yaml_config)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup necessary environment variables (Skip if using the above Quick Start Template)\n",
    "\n",
    "You should setup the environment variables in order to run this sample. More environment variables can be set by referring to [feathr_config.yaml](https://github.com/linkedin/feathr/blob/main/feathr_project/feathrcli/data/feathr_user_workspace/feathr_config.yaml) and use that as the source of truth. It also has more explanations on the meaning of each variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.environ['REDIS_PASSWORD'] = ''\n",
    "# os.environ['AZURE_CLIENT_ID'] = ''\n",
    "# os.environ['AZURE_TENANT_ID'] = ''\n",
    "# os.environ['AZURE_CLIENT_SECRET'] = ''\n",
    "\n",
    "# # Optional envs if you are using different runtimes\n",
    "# os.environ['DATABRICKS_WORKSPACE_TOKEN_VALUE'] = ''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize Feathr Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = FeathrClient(config_path=tmp.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View the data\n",
    "\n",
    "In this tutorial, we use Feathr Feature Store to create a model that predicts NYC Taxi fares. The dataset comes from [here](https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page). The data is as below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "pd.read_csv(\"https://azurefeathrstorage.blob.core.windows.net/public/sample_data/green_tripdata_2020-04_with_index.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining Features with Feathr\n",
    "\n",
    "In Feathr, a feature is viewed as a function, mapping from entity id or key, and timestamp to a feature value. For more details on feature definition, please refer to the [Feathr Feature Definition Guide](https://github.com/linkedin/feathr/blob/main/docs/concepts/feature-definition.md)\n",
    "\n",
    "\n",
    "1. The typed key (a.k.a. entity id) identifies the subject of feature, e.g. a user id, 123.\n",
    "2. The feature name is the aspect of the entity that the feature is indicating, e.g. the age of the user.\n",
    "3. The feature value is the actual value of that aspect at a particular time, e.g. the value is 30 at year 2022.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that, in some cases, such as features defined on top of request data, may have no entity key or timestamp.\n",
    "It is merely a function/transformation executing against request data at runtime.\n",
    "For example, the day of week of the request, which is calculated by converting the request UNIX timestamp.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Sources Section with UDFs\n",
    "A feature source is needed for anchored features that describes the raw data in which the feature values are computed from. See the python documentation to get the details on each input column.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession, DataFrame\n",
    "def feathr_udf_day_calc(df: DataFrame) -> DataFrame:\n",
    "    from pyspark.sql.functions import dayofweek, dayofyear, col\n",
    "    df = df.withColumn(\"fare_amount_cents\", col(\"fare_amount\")*100)\n",
    "    return df\n",
    "\n",
    "batch_source = HdfsSource(name=\"nycTaxiBatchSource\",\n",
    "                          path=\"wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_with_index.csv\",\n",
    "                          event_timestamp_column=\"lpep_dropoff_datetime\",\n",
    "                          preprocessing=feathr_udf_day_calc,\n",
    "                          timestamp_format=\"yyyy-MM-dd HH:mm:ss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Anchors and Features\n",
    "A feature is called an anchored feature when the feature is directly extracted from the source data, rather than computed on top of other features. The latter case is called derived feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f_trip_distance = Feature(name=\"f_trip_distance\",\n",
    "                          feature_type=FLOAT, transform=\"trip_distance\")\n",
    "f_trip_time_duration = Feature(name=\"f_trip_time_duration\",\n",
    "                               feature_type=INT32,\n",
    "                               transform=\"(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60\")\n",
    "\n",
    "features = [\n",
    "    f_trip_distance,\n",
    "    f_trip_time_duration,\n",
    "    Feature(name=\"f_is_long_trip_distance\",\n",
    "            feature_type=BOOLEAN,\n",
    "            transform=\"cast_float(trip_distance)>30\"),\n",
    "    Feature(name=\"f_day_of_week\",\n",
    "            feature_type=INT32,\n",
    "            transform=\"dayofweek(lpep_dropoff_datetime)\"),\n",
    "]\n",
    "\n",
    "request_anchor = FeatureAnchor(name=\"request_features\",\n",
    "                               source=INPUT_CONTEXT,\n",
    "                               features=features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Window aggregation features\n",
    "\n",
    "For window aggregation features, see the supported fields below:\n",
    "\n",
    "Note that the `agg_func` should be any of these:\n",
    "\n",
    "| Aggregation Type | Input Type | Description |\n",
    "| --- | --- | --- |\n",
    "|SUM, COUNT, MAX, MIN, AVG\t|Numeric|Applies the the numerical operation on the numeric inputs. |\n",
    "|MAX_POOLING, MIN_POOLING, AVG_POOLING\t| Numeric Vector | Applies the max/min/avg operation on a per entry bassis for a given a collection of numbers.|\n",
    "|LATEST| Any |Returns the latest not-null values from within the defined time window |\n",
    "\n",
    "\n",
    "After you have defined features and sources, bring them together to build an anchor:\n",
    "\n",
    "\n",
    "Note that if the data source is from the observation data, the `source` section should be `INPUT_CONTEXT` to indicate the source of those defined anchors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "location_id = TypedKey(key_column=\"DOLocationID\",\n",
    "                       key_column_type=ValueType.INT32,\n",
    "                       description=\"location id in NYC\",\n",
    "                       full_name=\"nyc_taxi.location_id\")\n",
    "agg_features = [Feature(name=\"f_location_avg_fare\",\n",
    "                        key=location_id,\n",
    "                        feature_type=FLOAT,\n",
    "                        transform=WindowAggTransformation(agg_expr=\"cast_float(fare_amount)\",\n",
    "                                                          agg_func=\"AVG\",\n",
    "                                                          window=\"90d\")),\n",
    "                Feature(name=\"f_location_max_fare\",\n",
    "                        key=location_id,\n",
    "                        feature_type=FLOAT,\n",
    "                        transform=WindowAggTransformation(agg_expr=\"cast_float(fare_amount)\",\n",
    "                                                          agg_func=\"MAX\",\n",
    "                                                          window=\"90d\")),\n",
    "                Feature(name=\"f_location_total_fare_cents\",\n",
    "                        key=location_id,\n",
    "                        feature_type=FLOAT,\n",
    "                        transform=WindowAggTransformation(agg_expr=\"fare_amount_cents\",\n",
    "                                                          agg_func=\"SUM\",\n",
    "                                                          window=\"90d\")),\n",
    "                ]\n",
    "\n",
    "agg_anchor = FeatureAnchor(name=\"aggregationFeatures\",\n",
    "                           source=batch_source,\n",
    "                           features=agg_features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Derived Features Section\n",
    "Derived features are the features that are computed from other features. They could be computed from anchored features, or other derived features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f_trip_time_distance = DerivedFeature(name=\"f_trip_time_distance\",\n",
    "                                      feature_type=FLOAT,\n",
    "                                      input_features=[\n",
    "                                          f_trip_distance, f_trip_time_duration],\n",
    "                                      transform=\"f_trip_distance * f_trip_time_duration\")\n",
    "\n",
    "f_trip_time_rounded = DerivedFeature(name=\"f_trip_time_rounded\",\n",
    "                                     feature_type=INT32,\n",
    "                                     input_features=[f_trip_time_duration],\n",
    "                                     transform=\"f_trip_time_duration % 10\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And then we need to build those features so that it can be consumed later. Note that we have to build both the \"anchor\" and the \"derived\" features (which is not anchored to a source)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[\n",
    "                      f_trip_time_distance, f_trip_time_rounded])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create training data using point-in-time correct feature join\n",
    "\n",
    "A training dataset usually contains entity id columns, multiple feature columns, event timestamp column and label/target column. \n",
    "\n",
    "To create a training dataset using Feathr, one needs to provide a feature join configuration file to specify\n",
    "what features and how these features should be joined to the observation data. \n",
    "\n",
    "To learn more on this topic, please refer to [Point-in-time Correctness](https://github.com/linkedin/feathr/blob/main/docs/concepts/point-in-time-join.md)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if client.spark_runtime == 'databricks':\n",
    "    output_path = 'dbfs:/feathrazure_test.avro'\n",
    "else:\n",
    "    output_path = feathr_output_path\n",
    "\n",
    "\n",
    "feature_query = FeatureQuery(\n",
    "    feature_list=[\"f_location_avg_fare\", \"f_trip_time_rounded\", \"f_is_long_trip_distance\", \"f_location_total_fare_cents\"], key=location_id)\n",
    "settings = ObservationSettings(\n",
    "    observation_path=\"wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04_with_index.csv\",\n",
    "    event_timestamp_column=\"lpep_dropoff_datetime\",\n",
    "    timestamp_format=\"yyyy-MM-dd HH:mm:ss\")\n",
    "client.get_offline_features(observation_settings=settings,\n",
    "                            feature_query=feature_query,\n",
    "                            output_path=output_path)\n",
    "client.wait_job_to_finish(timeout_sec=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download the result and show the result\n",
    "\n",
    "Let's use the helper function `get_result_df` to download the result and view it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_result_df(client: FeathrClient) -> pd.DataFrame:\n",
    "    \"\"\"Download the job result dataset from cloud as a Pandas dataframe.\"\"\"\n",
    "    res_url = client.get_job_result_uri(block=True, timeout_sec=600)\n",
    "    tmp_dir = tempfile.TemporaryDirectory()\n",
    "    client.feathr_spark_laucher.download_result(result_path=res_url, local_folder=tmp_dir.name)\n",
    "    dataframe_list = []\n",
    "    # assuming the result are in avro format\n",
    "    for file in glob.glob(os.path.join(tmp_dir.name, '*.avro')):\n",
    "        dataframe_list.append(pdx.read_avro(file))\n",
    "    vertical_concat_df = pd.concat(dataframe_list, axis=0)\n",
    "    tmp_dir.cleanup()\n",
    "    return vertical_concat_df\n",
    "\n",
    "df_res = get_result_df(client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a machine learning model\n",
    "After getting all the features, let's train a machine learning model with the converted feature by Feathr:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove columns\n",
    "from sklearn.ensemble import GradientBoostingRegressor\n",
    "final_df = df_res\n",
    "final_df.drop([\"lpep_pickup_datetime\", \"lpep_dropoff_datetime\",\n",
    "              \"store_and_fwd_flag\"], axis=1, inplace=True, errors='ignore')\n",
    "final_df.fillna(0, inplace=True)\n",
    "final_df['fare_amount'] = final_df['fare_amount'].astype(\"float64\")\n",
    "\n",
    "\n",
    "train_x, test_x, train_y, test_y = train_test_split(final_df.drop([\"fare_amount\"], axis=1),\n",
    "                                                    final_df[\"fare_amount\"],\n",
    "                                                    test_size=0.2,\n",
    "                                                    random_state=42)\n",
    "model = GradientBoostingRegressor()\n",
    "model.fit(train_x, train_y)\n",
    "\n",
    "y_predict = model.predict(test_x)\n",
    "\n",
    "y_actual = test_y.values.flatten().tolist()\n",
    "rmse = sqrt(mean_squared_error(y_actual, y_predict))\n",
    "\n",
    "sum_actuals = sum_errors = 0\n",
    "\n",
    "for actual_val, predict_val in zip(y_actual, y_predict):\n",
    "    abs_error = actual_val - predict_val\n",
    "    if abs_error < 0:\n",
    "        abs_error = abs_error * -1\n",
    "\n",
    "    sum_errors = sum_errors + abs_error\n",
    "    sum_actuals = sum_actuals + actual_val\n",
    "\n",
    "mean_abs_percent_error = sum_errors / sum_actuals\n",
    "print(\"Model MAPE:\")\n",
    "print(mean_abs_percent_error)\n",
    "print()\n",
    "print(\"Model Accuracy:\")\n",
    "print(1 - mean_abs_percent_error)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Materialize feature value into offline/online storage\n",
    "\n",
    "While Feathr can compute the feature value from the feature definition on-the-fly at request time, it can also pre-compute\n",
    "and materialize the feature value to offline and/or online storage. \n",
    "\n",
    "We can push the generated features to the online store like below:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "backfill_time = BackfillTime(start=datetime(\n",
    "    2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))\n",
    "redisSink = RedisSink(table_name=\"nycTaxiDemoFeature\")\n",
    "settings = MaterializationSettings(\"nycTaxiTable\",\n",
    "                                   backfill_time=backfill_time,\n",
    "                                   sinks=[redisSink],\n",
    "                                   feature_names=[\"f_location_avg_fare\", \"f_location_max_fare\"])\n",
    "\n",
    "client.materialize_features(settings)\n",
    "client.wait_job_to_finish(timeout_sec=500)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then get the features from the online store (Redis):\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fetching feature value for online inference\n",
    "\n",
    "For features that are already materialized by the previous step, their latest value can be queried via the client's\n",
    "`get_online_features` or `multi_get_online_features` API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = client.get_online_features('nycTaxiDemoFeature', '265', [\n",
    "                                 'f_location_avg_fare', 'f_location_max_fare'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client.multi_get_online_features(\"nycTaxiDemoFeature\", [\"239\", \"265\"], [\n",
    "                                 'f_location_avg_fare', 'f_location_max_fare'])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Registering and Fetching features\n",
    "\n",
    "We can also register the features with an Apache Atlas compatible service, such as Azure Purview, and share the registered features across teams:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client.register_features()\n",
    "client.list_registered_features(project_name=\"feathr_getting_started\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "830c16c5b424e7ff512f67d4056b67cea1a756a7ad6a92c98b9e2b95c5e484ae"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
